import matplotlib.pyplot as pyplot
import numpy
from matplotlib import gridspec
import scipy.integrate
import scipy.optimize
import time
import warnings
from jqc import jqc_plot
import os


# Directory containing this script (not the process working directory);
# figures are written under it.
cwd = os.path.dirname(os.path.abspath(__file__))
jqc_plot.plot_style('normal')
# NOTE(review): this silences *all* warnings, including odeint / curve_fit
# convergence warnings -- consider narrowing it.
warnings.filterwarnings("ignore")

########################### Set Global Parameters ######################################
T0 = 2.6e-6  # Initial temperature (at t = 0), Kelvin
# Geometric mean of the three 170 Hz trap frequencies, expressed as an
# angular frequency (rad/s) because of the 2*pi factor.
w = 2*numpy.pi*(170*170*170)**(1/3)
m = 220*1.660539e-27  # Molecular mass: 220 atomic mass units, in kg
kB = 1.38065e-23  # Boltzmann constant, J/K

# Number of ODE output points per interval between consecutive data times in
# the fit functions. Must be an int: numpy.linspace rejects a float count.
Ndat = 1000

# Two-body rate coefficient in cm^3 s^-1 (an earlier value used was 4.8e-11)
K2 = 5.31e-11

# Measured molecule numbers are divided by this factor before fitting
STIRAPEFFICIENCY = 0.89

# NOTE(review): only TwoBodyBkgdParams is read by this script as shown
PlotTwoBodyBackground, TwoBodyBkgdParams = True, True
PlotTwoBodyOnly, TwoBodyOnlyParams = False, True
# If True, b is also fitted for the first intensity instead of being fixed from K2
fit_first = False

# If True, TwoBody keeps the temperature constant (dT/dt = 0)
ConstantTemperature = False

# The fit label and parameter printout below reference the chi-square; keep True
ShowChiSquare = True
print(' ')
###################### Define Necessary Fit Functions ##################################

def GetAveragedData(Data):
    """Average repeated measurements that share the same time value.

    Parameters
    ----------
    Data : 2-D array-like
        Measurement rows. Column 0 is the time, column 1 is the quantity to
        average; any further columns are ignored.

    Returns
    -------
    numpy.ndarray, shape (n_times, 3)
        One row per distinct time, in ascending time order:
        (time, mean of column 1, standard error of the mean).
        The standard error is the population standard deviation
        (numpy.std, ddof=0) divided by sqrt(count). A time with a single
        measurement gets a standard error of 0. Empty input gives shape (0, 3).
    """
    Data = numpy.asarray(Data, dtype=float)
    # Sort by time so that all rows with the same time are contiguous
    Data = Data[numpy.argsort(Data[:, 0])]

    Rows = []
    Start = 0
    nRows = len(Data)
    for Stop in range(1, nRows + 1):
        # Close the current group at the end of the data or when the time changes.
        # Slicing keeps the group 2-D even when it holds a single row.
        if Stop == nRows or Data[Stop, 0] != Data[Stop - 1, 0]:
            Values = Data[Start:Stop, 1]
            ErrN = numpy.std(Values)/numpy.sqrt(len(Values))
            Rows.append((Data[Start, 0], numpy.average(Values), ErrN))
            Start = Stop

    # reshape keeps a (0, 3) shape when there is no data
    return numpy.array(Rows, dtype=float).reshape(-1, 3)

def TwoBodyBkgd(Input, t, a, b):
    """Right-hand side of the one-body + two-body loss ODE, in odeint form.

    Parameters
    ----------
    Input : pair (N, T)
        Current molecule number and temperature.
    t : float
        Time; unused, but required by the odeint callback signature.
    a : float
        One-body loss rate (dN/dt contains -a*N). Its absolute value is used,
        which keeps the fitted rate non-negative.
    b : float
        Two-body coefficient (dN/dt contains -b*N**2/T**1.5). Its absolute
        value is used.

    Returns
    -------
    tuple (dN/dt, dT/dt)
        dT/dt = (b/4) * N / sqrt(T).
    """
    N, T = Input
    one_body_rate = abs(a)
    two_body_coeff = abs(b)

    one_body_loss = N*one_body_rate
    two_body_loss = two_body_coeff*(N**2)/(T**(3/2))
    heating = (two_body_coeff/4)*(N/(numpy.sqrt(T)))

    return (-one_body_loss - two_body_loss, heating)


def TwoBody(Input, t, b):
    """Right-hand side of the pure two-body loss ODE (no one-body term).

    Parameters
    ----------
    Input : pair (N, T)
        Current molecule number and temperature.
    t : float
        Time; unused, but required by the odeint callback signature.
    b : float
        Two-body coefficient (dN/dt = -b*N**2/T**1.5). Its absolute value is used.

    Returns
    -------
    tuple (dN/dt, dT/dt)
        dT/dt = (b/4) * N / sqrt(T), replaced by 0 when the module flag
        ConstantTemperature is set.
    """
    N, T = Input
    coeff = abs(b)

    loss = -coeff*(N**2)/(T**(3/2))
    heating = (coeff/4)*(N/(numpy.sqrt(T)))
    # Optionally freeze the temperature
    if ConstantTemperature == True:
        heating = 0

    return (loss, heating)

def _SolveAtDataTimes(rhs, args, xData, N0):
    """Integrate `rhs` from (N0, T0) and return the modelled N at every x value.

    Parameters
    ----------
    rhs : callable
        odeint right-hand side with signature rhs((N, T), t, *args).
    args : tuple
        Extra arguments forwarded to `rhs`.
    xData : 1-D array-like
        Times at which N is wanted. They may be unsorted and may repeat.
    N0 : float
        Molecule number at the initial time.

    Returns
    -------
    numpy.ndarray
        Modelled N, one entry per element of xData, in the same order.

    The initial condition (N0, T0) is placed at t = 0 when every data time is
    positive, matching the plotted model curve. Otherwise it is placed at the
    earliest data time. NOTE(review): negative times are assumed not to occur.
    Each interval between consecutive distinct times is subdivided into Ndat
    odeint output points, ending exactly on the next time.
    """
    xData = numpy.asarray(xData, dtype=float).ravel()
    if xData.size == 0:
        return numpy.zeros(0)

    nSub = int(Ndat)  # linspace needs an integer sample count
    UniqueT, Inverse = numpy.unique(xData, return_inverse=True)
    if UniqueT[0] > 0:
        Nodes = numpy.concatenate(([0.0], UniqueT))
        Offset = 1  # Nodes[0] is the synthetic t = 0 start, not a data time
    else:
        Nodes = UniqueT
        Offset = 0

    # Segment k ends at Nodes[k + 1], so node k sits at grid index k*nSub
    Segments = [numpy.linspace(Lo + (Hi - Lo)/nSub, Hi, nSub)
                for Lo, Hi in zip(Nodes[:-1], Nodes[1:])]
    Times = numpy.concatenate([Nodes[:1]] + Segments)

    FitLine = scipy.integrate.odeint(rhs, (N0, T0), t=Times, args=args)
    NAtNodes = FitLine[::nSub, 0]
    # Map each distinct data time back onto the original (possibly unsorted) order
    return NAtNodes[Offset:][Inverse.ravel()]

def TwoBodyBkgdFitFunction(xData, N0, a, b):
    """Model N(t) for one-body + two-body loss with heating (TwoBodyBkgd).

    Suitable as a curve_fit model: returns one modelled molecule number per
    entry of xData, for initial number N0 (at t = 0), one-body rate a and
    two-body coefficient b. The initial temperature is the module-level T0.
    """
    # Change the right-hand side here to change the fit function
    return _SolveAtDataTimes(TwoBodyBkgd, (a, b), xData, N0)

def TwoBodyFitFunction(xData, N0, b):
    """Model N(t) for pure two-body loss with heating (TwoBody).

    Suitable as a curve_fit model: returns one modelled molecule number per
    entry of xData, for initial number N0 (at t = 0) and two-body coefficient b.
    The initial temperature is the module-level T0.
    """
    # Change the right-hand side here to change the fit function
    return _SolveAtDataTimes(TwoBody, (b,), xData, N0)

def CalcK2(ValueRaw, ErrorRaw):
    """Convert the ODE two-body coefficient b (and its error) into K2.

    K2 = b * (4*pi*kB / (m*w**2))**1.5, which is m^3 s^-1 in SI units. The
    result is then scaled by 1e6 to give cm^3 s^-1. This is the inverse of
    Calcb.

    Returns
    -------
    tuple (K2, K2_error) in cm^3 s^-1.
    """
    scale = ((4*numpy.pi*kB)/(m*(w**2)))**1.5
    K2Value = (ValueRaw*scale)*1e6
    K2Error = (ErrorRaw*scale)*1e6
    return K2Value, K2Error

def Calcb(ValueRaw):
    """Convert a two-body rate coefficient K2 in cm^3 s^-1 into the ODE coefficient b.

    The factor 1e-6 converts cm^3 to m^3. The result is then divided by
    (4*pi*kB / (m*w**2))**1.5. This is the inverse of CalcK2.
    """
    volumeFactor = ((4*numpy.pi*kB)/(m*(w**2)))**1.5
    bValue = ValueRaw/volumeFactor
    return bValue*1e-6


def ErrorRound(Value, Error):
    """Round an uncertainty to one significant figure and the value to match.

    Parameters
    ----------
    Value : float
        Central value.
    Error : float
        Uncertainty. Its sign is ignored.

    Returns
    -------
    tuple (Value, Error)
        The error rounded to one significant figure, and the value rounded to
        the decimal place of that rounded error. If the error is 0 or
        non-finite (+/-inf or NaN, e.g. from a degenerate fit covariance), the
        error is reported as 0 and the value is returned unrounded.
    """
    # numpy.isfinite also catches NaN, which never compares equal to float("nan")
    if Error == 0 or not numpy.isfinite(Error):
        return Value, 0

    # Decimal position of the error's leading significant digit
    Digits = -int(numpy.floor(numpy.log10(abs(Error))))
    Error = round(abs(Error), Digits)
    # Round the value using the *rounded* error, which may have gained a digit
    # (e.g. 0.96 -> 1.0)
    Value = round(Value, -int(numpy.floor(numpy.log10(Error))))
    return Value, Error

def ChiSquared(Exp, Obs, *args):
    """Mean per-point chi-square between expected and observed values.

    With an error sequence as the optional third argument, each term is
    (Exp - Obs)**2 / Err**2. Without it, each term is the Pearson form
    (Exp - Obs)**2 / Exp.

    The sum is divided by the number of points, not by the number of degrees
    of freedom (points minus fitted parameters).
    """
    indices = range(len(Exp))
    if not args:
        terms = [((Exp[k] - Obs[k])**2)/Exp[k] for k in indices]
    else:
        sigma = args[0]
        terms = [((Exp[k] - Obs[k])**2)/(sigma[k])**2 for k in indices]
    return sum(terms)/len(terms)

def CalcKSticky(ValueRaw, ErrorRaw):
    """Scale a rate (and its error) by the volume factor (2*pi*kB*T0 / (m*w**2))**1.5.

    Both values are divided by (m*w**2 / (2*pi*kB*T0))**1.5 and then scaled
    by 1e6.
    NOTE(review): the original inline comments label the result both
    m^3 s^-1 and cm^6 s^-1. The 1e6 factor matches an m^3 -> cm^3
    conversion. Confirm the intended units of ValueRaw and of the result.
    """
    inverseVolume = ((m*(w**2))/(2*numpy.pi*kB*T0))**1.5
    scaledValue = (ValueRaw/inverseVolume)*1e6
    scaledError = (ErrorRaw/inverseVolume)*1e6
    return scaledValue, scaledError

############################### JQC Colours ###################################
JQC = jqc_plot.colours

########################## Initialise the plot ################################
# Figure size passed to pyplot.figure is FigRatio * Scaling
Scaling = 0.82
FigRatio = numpy.array((10, 8.25))

# Single-panel layout: 1 row x 1 column
gs = gridspec.GridSpec(1, 1, width_ratios=[1], height_ratios=[1])

# Intensities to process; each one is read from the file "I=<value>.csv".
# Full scan used previously: [0., 0.3, 0.6, 0.9, 1.2, 1.4, 1.6, 1.8]
Intensities = [20.]

# One row per intensity: N0, N0 error, a, a error, K2, K2 error (K2 in cm^3 s^-1)
output = numpy.zeros((len(Intensities), 6))
# ODE two-body coefficient derived from the assumed K2
b = Calcb(K2)
########################## Fit Two Body Collisions (w/BKGD) ###################
for i, I in enumerate(Intensities):
    # One figure and one fit per intensity. The data for intensity I is read
    # from "I=<I:.1f>.csv", relative to the process working directory.
    start_time = time.time()
    filename = "I={:.1f}".format(I)
    fig = pyplot.figure(filename, figsize=FigRatio*Scaling, dpi=100)
    sub = fig.add_subplot(gs[0])

    # Main axes: molecule number against time
    sub.set_ylabel("Molecule Number", fontsize=18)
    sub.set_xlabel("$t$ (s)", fontsize=18)

    # Inset: modelled temperature against time. The label is a raw string
    # because its LaTeX backslashes are invalid Python escape sequences.
    inset = fig.add_axes([.6, .62, 0.35, 0.3], facecolor='w')
    inset.set_ylabel(r"T ($\mathrm{\mu}$K)", fontsize=18)
    inset.set_xlabel("$t$ (s)", fontsize=18)
    TempScaler = []

    ########################### Plot the Lifetime Experimental Data ###########
    # Data format: first column is time in s, second column is molecule number.
    # atleast_2d keeps a one-row file 2-D so the column indexing below works.
    Data = numpy.atleast_2d(numpy.genfromtxt(filename+".csv", delimiter=","))
    # Time-order the rows. The fit functions build their odeint output grid
    # from consecutive x values, and that grid must be monotonic.
    Data = Data[numpy.argsort(Data[:, 0], kind='stable')]

    # Scale numbers by 1/STIRAPEFFICIENCY (presumably correcting for the
    # transfer efficiency)
    Data[:, 1] = Data[:, 1]/STIRAPEFFICIENCY

    # Average repeated shots: columns are (t, mean N, standard error)
    AveragedData = GetAveragedData(Data)

    sub.errorbar(AveragedData[:, 0], AveragedData[:, 1], yerr=AveragedData[:, 2],
                 fmt='o', markeredgewidth=1, markeredgecolor='black',
                 markerfacecolor='white', capsize=3.5, ecolor='black', zorder=20)

    # Sensible scaling of axes to fit the data
    tMin, tMax = Data[:, 0].min(), Data[:, 0].max()
    sub.set_xlim(tMin - 0.01*tMax, 1.05*tMax)
    sub.set_ylim(0, 1.01*Data[:, 1].max())
    inset.set_xlim(tMin - 0.01*tMax, 1.05*tMax)

    # Initial guesses: N0, a, and b only when b is also fitted
    if i == 0 and fit_first:
        x0 = numpy.array([2500, 1e-6, b])
        fitfn = lambda x, N0, a, b: TwoBodyBkgdFitFunction(x, N0, a, b)
        popt2, pcov2 = scipy.optimize.curve_fit(fitfn, Data[:, 0], Data[:, 1], p0=x0)
        # The fitted b rebinds the module-level b used for later intensities
        N0, a, b = abs(popt2)
    else:
        # b is held fixed at its current module-level value
        x0 = numpy.array([2500, 1e-6])
        fitfn = lambda x, N0, a: TwoBodyBkgdFitFunction(x, N0, a, b)
        popt2, pcov2 = scipy.optimize.curve_fit(fitfn, Data[:, 0], Data[:, 1], p0=x0)
        N0, a = abs(popt2)
    Init = N0, T0

    # Dense grid for a smooth model curve; linspace needs an integer count
    FitTimes = numpy.linspace(0, 1.1*tMax, 1000000)
    Fit2BodyLine = scipy.integrate.odeint(TwoBodyBkgd, Init, t=FitTimes, args=(a, b))

    # Chi-square on individual shots (I) and on averaged points with errors (E).
    # ChiSquareE stays NaN when disabled, so the printout below still works.
    ChiSquareE = float("nan")
    FitLabel = '2-Body + 1-Body'
    if ShowChiSquare:
        Exp = TwoBodyBkgdFitFunction(Data[:, 0], N0, a, b)
        ChiSquareI = ChiSquared(Exp, Data[:, 1])
        Exp = TwoBodyBkgdFitFunction(AveragedData[:, 0], N0, a, b)
        ChiSquareE = ChiSquared(Exp, AveragedData[:, 1], AveragedData[:, 2])
        FitLabel += r', $\chi^2$ = %s' % round(ChiSquareE, 1)

    sub.plot(FitTimes, Fit2BodyLine[:, 0], color=JQC['red'], label=FitLabel)
    inset.plot(FitTimes, Fit2BodyLine[:, 1]*1e6, color=JQC['red'],
               label='2-Body + 1-Body')

    # ndarray.max avoids a Python-level loop over the 1e6-point solution
    TempScaler.append(Fit2BodyLine[:, 1].max())

    if TwoBodyBkgdParams:
        ValueRaw = abs(popt2[0])
        ErrorRaw = abs(numpy.sqrt(pcov2[0][0]))
        Value, Error = ErrorRound(ValueRaw, ErrorRaw)
        print('2 Body Fit + Background Parameters:\n',
              'N0 =', Value, '+/-', Error, '\n')
        output[i, 0] = ValueRaw
        output[i, 1] = ErrorRaw

        # a multiplies N in dN/dt (TwoBodyBkgd), so it is a rate in s^-1, not a lifetime
        ValueRaw = abs(popt2[1])
        ErrorRaw = abs(numpy.sqrt(pcov2[1][1]))
        Value, Error = ErrorRound(ValueRaw, ErrorRaw)
        print('', "1-body rate (a) =", Value, '+/-', Error, 's^-1\n')
        output[i, 2] = ValueRaw
        output[i, 3] = ErrorRaw

        # b only has a fit uncertainty when it was a free parameter
        ValueRaw = abs(b)
        if len(popt2) == 3:
            ErrorRaw = abs(numpy.sqrt(pcov2[2][2]))
        else:
            ErrorRaw = 0
        ValueRaw, ErrorRaw = CalcK2(ValueRaw, ErrorRaw)
        Value, Error = ErrorRound(ValueRaw, ErrorRaw)
        print('Chi Square =', ChiSquareE, ' K2 (b) =',
              Value, '+/-', Error, 'cm^3 s^-1')
        output[i, 4] = ValueRaw
        output[i, 5] = ErrorRaw

        print('---------------------------------------------------')

    inset.set_ylim(0.99*T0*1e6, 1.01*max(TempScaler)*1e6)
    fig.tight_layout()
    print("Took", time.time() - start_time, "seconds to find solution")
    print('---------------------------------------------------')

    # Save into <script dir>/Figures, creating it if needed. os.path.join
    # keeps the path valid on both Windows and POSIX.
    FigDir = os.path.join(cwd, "Figures")
    os.makedirs(FigDir, exist_ok=True)
    fig.savefig(os.path.join(FigDir, filename+"inc2body.pdf"), dpi=600)
# One row per intensity: N0, N0 error, a, a error, K2, K2 error (K2 in cm^3 s^-1).
# Written relative to the process working directory.
numpy.savetxt("Lifetimes_pars_inc2body.csv", output, delimiter=',')

pyplot.show()
